Imports

import torch 
from fastai.vision.all import * 
import cv2 
import fastbook
from fastbook import *
from fastai.vision.widgets import *

Data

# Root of the chest X-ray dataset (contains train/, test/ and val/ subfolders).
path=Path('/home/khy/chest_xray/chest_xray') 
path.ls()
(#5) [Path('/home/khy/chest_xray/chest_xray/train'),Path('/home/khy/chest_xray/chest_xray/test'),Path('/home/khy/chest_xray/chest_xray/chest_xray'),Path('/home/khy/chest_xray/chest_xray/__MACOSX'),Path('/home/khy/chest_xray/chest_xray/val')]
# Recursively collect every image file under the dataset root.
files=get_image_files(path)
files
(#11712) [Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0766-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1318-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0160-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-1327-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0489-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0509-0001-0002.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0761-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0416-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/NORMAL2-IM-0566-0001.jpeg'),Path('/home/khy/chest_xray/chest_xray/train/NORMAL/IM-0411-0001.jpeg')...]
# Build DataLoaders from the folder tree, labelling each image by its parent
# folder name (NORMAL / PNEUMONIA) and resizing every item to 224x224.
# NOTE(review): valid_pct=0.2 draws a *random* 20% validation split from all
# images found under `path` -- which here includes the test/ and val/ folders
# as well as train/ -- rather than honouring the dataset's own split; confirm
# this is intended.
dls = ImageDataLoaders.from_folder(path, train='train', valid_pct=0.2, item_tfms=Resize(224))      
dls.vocab
['NORMAL', 'PNEUMONIA']
dls.show_batch(max_n=16)
# Build a CAM-friendly model: keep the pretrained ResNet34 body (net1) but
# replace fastai's default head with a minimal GAP -> Flatten -> Linear head
# (net2).  A single bias-free Linear after global average pooling is what
# makes the class-activation-map einsum used below valid.
learn = cnn_learner(dls, resnet34, metrics=error_rate)
net1 = learn.model[0]  # convolutional body (512-channel feature maps)
# (the original `net2 = learn.model[1]` was a dead store -- it was
#  immediately overwritten by the Sequential below, so it is dropped)
net2 = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d(output_size=1),
    torch.nn.Flatten(),
    # bias=False so each class score is a pure weighted sum of channel
    # activations -- required for the CAM weight x feature-map product.
    torch.nn.Linear(512, out_features=2, bias=False))
net = torch.nn.Sequential(net1, net2)
lrnr2 = Learner(dls, net, metrics=accuracy)
# NOTE(review): 200 epochs is far past where the validation loss plateaus in
# the training log below; consider early stopping or fewer epochs.
lrnr2.fine_tune(200)
epoch train_loss valid_loss accuracy time
0 0.166842 0.091861 0.967122 00:39
epoch train_loss valid_loss accuracy time
0 0.076691 0.070642 0.973954 00:38
1 0.065596 0.065189 0.976943 00:38
2 0.063810 0.060881 0.977797 00:38
3 0.058133 0.055606 0.979505 00:38
4 0.047295 0.051751 0.982494 00:38
5 0.049507 0.061955 0.975235 00:38
6 0.040383 0.048890 0.982494 00:38
7 0.037072 0.038793 0.985483 00:38
8 0.029895 0.035411 0.988044 00:38
9 0.024122 0.032279 0.988471 00:38
10 0.022319 0.030799 0.990606 00:38
11 0.022883 0.029063 0.990606 00:38
12 0.018799 0.024217 0.993595 00:38
13 0.018655 0.026862 0.991887 00:38
14 0.017203 0.025556 0.991460 00:38
15 0.012168 0.028741 0.991887 00:38
16 0.013291 0.021540 0.991460 00:38
17 0.013113 0.023177 0.993595 00:38
18 0.014589 0.023715 0.993168 00:38
19 0.010889 0.027784 0.992314 00:38
20 0.010598 0.028819 0.992314 00:38
21 0.013652 0.023543 0.993168 00:38
22 0.010165 0.021542 0.993168 00:38
23 0.011329 0.024496 0.992314 00:38
24 0.009473 0.019847 0.992314 00:38
25 0.007470 0.022198 0.990179 00:38
26 0.007615 0.017968 0.995303 00:38
27 0.006131 0.022273 0.995730 00:38
28 0.008292 0.032437 0.992314 00:38
29 0.008912 0.042545 0.988898 00:38
30 0.009870 0.039163 0.988044 00:38
31 0.010967 0.018784 0.992314 00:38
32 0.006510 0.021688 0.991887 00:38
33 0.006636 0.033374 0.991460 00:38
34 0.010336 0.020198 0.993595 00:38
35 0.009317 0.030448 0.991033 00:38
36 0.007046 0.022307 0.993168 00:38
37 0.009590 0.026956 0.990606 00:38
38 0.006269 0.055886 0.985056 00:38
39 0.010013 0.018850 0.994449 00:38
40 0.008058 0.027818 0.993168 00:38
41 0.007327 0.015476 0.993595 00:38
42 0.006886 0.010855 0.997011 00:38
43 0.011692 0.017141 0.997011 00:38
44 0.007462 0.030888 0.990179 00:38
45 0.006464 0.015794 0.992741 00:38
46 0.007760 0.068463 0.984628 00:38
47 0.006637 0.015711 0.993168 00:38
48 0.010105 0.041067 0.988898 00:38
49 0.007672 0.012651 0.996157 00:38
50 0.014199 0.083004 0.974381 00:38
51 0.012289 0.018203 0.993168 00:38
52 0.009026 0.020449 0.994022 00:38
53 0.004553 0.017501 0.993595 00:38
54 0.010326 0.024923 0.991033 00:38
55 0.015319 0.027962 0.992314 00:38
56 0.004357 0.023815 0.994022 00:38
57 0.005287 0.019874 0.992314 00:38
58 0.009573 0.014026 0.995730 00:38
59 0.006735 0.021964 0.993168 00:38
60 0.005811 0.023319 0.990606 00:38
61 0.011406 0.026691 0.992741 00:38
62 0.005277 0.022868 0.994449 00:38
63 0.006119 0.018390 0.994022 00:38
64 0.007875 0.034545 0.994022 00:38
65 0.005800 0.020408 0.994022 00:38
66 0.002680 0.019692 0.994449 00:38
67 0.006419 0.034546 0.991033 00:38
68 0.006348 0.053590 0.986763 00:38
69 0.005590 0.031790 0.993595 00:38
70 0.007865 0.029411 0.994876 00:38
71 0.002760 0.026847 0.993168 00:38
72 0.009839 0.030372 0.992741 00:38
73 0.008680 0.026388 0.992314 00:38
74 0.004330 0.031201 0.992741 00:38
75 0.009632 0.078810 0.984202 00:38
76 0.003771 0.022387 0.992741 00:38
77 0.006113 0.030133 0.992314 00:38
78 0.003496 0.028839 0.995303 00:38
79 0.003018 0.026174 0.994022 00:38
80 0.007461 0.030011 0.993595 00:38
81 0.004392 0.023791 0.994876 00:38
82 0.005972 0.068508 0.987617 00:38
83 0.006191 0.019870 0.996584 00:38
84 0.005330 0.020402 0.996584 00:38
85 0.002982 0.036186 0.993168 00:38
86 0.003956 0.019152 0.994022 00:38
87 0.006709 0.022051 0.994449 00:38
88 0.004887 0.043770 0.991460 00:38
89 0.004027 0.025353 0.993168 00:38
90 0.002959 0.029085 0.992741 00:38
91 0.003077 0.025070 0.993595 00:38
92 0.004699 0.024857 0.992741 00:38
93 0.002660 0.032952 0.995730 00:38
94 0.003100 0.025073 0.994876 00:38
95 0.002563 0.023130 0.994022 00:38
96 0.001407 0.023987 0.995730 00:38
97 0.002879 0.015754 0.996584 00:38
98 0.002273 0.019964 0.995730 00:38
99 0.001539 0.023395 0.994022 00:38
100 0.002776 0.019369 0.997438 00:38
101 0.001925 0.015023 0.996157 00:38
102 0.002006 0.039217 0.991887 00:38
103 0.003615 0.011737 0.997011 00:38
104 0.002477 0.016405 0.995730 00:38
105 0.001914 0.014328 0.997438 00:38
106 0.000848 0.020702 0.995730 00:38
107 0.005377 0.028292 0.994022 00:38
108 0.003150 0.019413 0.996584 00:38
109 0.001558 0.022858 0.995730 00:38
110 0.002981 0.022044 0.995730 00:38
111 0.003152 0.024832 0.993595 00:38
112 0.001988 0.016285 0.995730 00:38
113 0.000533 0.014695 0.995730 00:38
114 0.000902 0.017304 0.995730 00:39
115 0.001843 0.019725 0.995730 00:38
116 0.001038 0.020030 0.995730 00:38
117 0.000729 0.019264 0.994022 00:38
118 0.001277 0.027110 0.994876 00:38
119 0.001734 0.026816 0.993168 00:38
120 0.002050 0.020589 0.995730 00:38
121 0.002221 0.022525 0.995730 00:38
122 0.000572 0.027818 0.993168 00:38
123 0.001051 0.018991 0.994876 00:38
124 0.000295 0.019816 0.994876 00:38
125 0.001252 0.022995 0.995730 00:38
126 0.000770 0.021016 0.994449 00:38
127 0.000683 0.030154 0.994449 00:38
128 0.003303 0.026239 0.995730 00:38
129 0.001704 0.025088 0.994022 00:38
130 0.002516 0.010910 0.996584 00:38
131 0.000699 0.015325 0.996584 00:38
132 0.000870 0.013863 0.996584 00:38
133 0.000663 0.020103 0.995730 00:38
134 0.000980 0.012507 0.996584 00:38
135 0.000181 0.014895 0.995730 00:38
136 0.000645 0.030882 0.994022 00:38
137 0.000258 0.029726 0.994022 00:38
138 0.000154 0.019418 0.995730 00:38
139 0.000699 0.019971 0.995730 00:38
140 0.000355 0.024038 0.994876 00:38
141 0.000170 0.030813 0.994876 00:38
142 0.000657 0.027899 0.994876 00:38
143 0.001425 0.024708 0.995730 00:38
144 0.000381 0.020135 0.994022 00:38
145 0.000152 0.025634 0.994876 00:38
146 0.000075 0.018921 0.994876 00:38
147 0.000226 0.017673 0.994876 00:38
148 0.000224 0.023066 0.996584 00:38
149 0.000632 0.018082 0.994876 00:38
150 0.000625 0.016179 0.996584 00:38
151 0.000080 0.021201 0.994876 00:38
152 0.000068 0.021460 0.994022 00:38
153 0.000112 0.018794 0.995730 00:38
154 0.000080 0.021812 0.994876 00:38
155 0.000040 0.018293 0.995730 00:38
156 0.000171 0.018570 0.997438 00:38
157 0.000175 0.015313 0.996584 00:38
158 0.000464 0.016535 0.996584 00:38
159 0.000109 0.019572 0.996584 00:38
160 0.000062 0.021594 0.996584 00:38
161 0.000064 0.014384 0.996584 00:38
162 0.000014 0.020526 0.996584 00:38
163 0.000028 0.019420 0.995730 00:38
164 0.000042 0.030555 0.994876 00:38
165 0.000080 0.022019 0.996584 00:38
166 0.000079 0.030117 0.994876 00:38
167 0.000038 0.019891 0.996584 00:38
168 0.000027 0.024130 0.996584 00:38
169 0.000017 0.027270 0.995730 00:38
170 0.000032 0.018282 0.995730 00:38
171 0.000062 0.019155 0.996584 00:38
172 0.000059 0.023948 0.995730 00:38
173 0.000011 0.025428 0.995730 00:38
174 0.000011 0.019787 0.995730 00:38
175 0.000018 0.025644 0.995730 00:38
176 0.000185 0.021899 0.995730 00:38
177 0.000056 0.021866 0.995730 00:38
178 0.000061 0.022560 0.995730 00:38
179 0.000019 0.019159 0.995730 00:38
180 0.000009 0.024180 0.995730 00:38
181 0.000030 0.022470 0.995730 00:38
182 0.000007 0.020468 0.995730 00:38
183 0.000049 0.024680 0.995730 00:38
184 0.000009 0.019799 0.994876 00:38
185 0.000026 0.025008 0.995730 00:38
186 0.000028 0.029448 0.995730 00:38
187 0.000161 0.032871 0.995730 00:38
188 0.000334 0.028276 0.995730 00:38
189 0.000033 0.023425 0.995730 00:38
190 0.000012 0.027646 0.995730 00:38
191 0.000012 0.026857 0.995730 00:38
192 0.000120 0.025125 0.995730 00:38
193 0.000014 0.029498 0.995730 00:38
194 0.000010 0.028255 0.995730 00:38
195 0.000098 0.027213 0.995730 00:38
196 0.000031 0.024639 0.995730 00:38
197 0.000021 0.028268 0.995730 00:38
198 0.000005 0.021215 0.995730 00:38
199 0.000010 0.027356 0.995730 00:38

Checking the CAM results — epoch 200

def _show_cam_grid(start_idx):
    """Plot a 5x5 grid of dataset images (from ``start_idx``) with CAM overlays.

    For each image: run it through the test-DataLoader pipeline, compute the
    CAM (head weights x conv feature maps), softmax the two logits, and draw
    the predicted class's activation map over the decoded image with the
    class name and probability as the title.
    """
    fig, ax = plt.subplots(5, 5)
    k = start_idx
    for i in range(5):
        for j in range(5):
            x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
            # CAM: (2,512) head weights x (512,H,W) features -> (2,H,W) maps
            camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
            a, b = net(x).tolist()[0]
            # Numerically stable 2-class softmax (subtract the max logit
            # before exponentiating to avoid overflow).
            m = max(a, b)
            ea, eb = np.exp(a - m), np.exp(b - m)
            normalprob, pneumoniaprob = ea / (ea + eb), eb / (ea + eb)
            # Single drawing path instead of two duplicated branches.
            if normalprob > pneumoniaprob:
                cls, label, prob = 0, "normal", normalprob
            else:
                cls, label, prob = 1, "pneumonia", pneumoniaprob
            dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
            ax[i][j].imshow(camimg[cls].to("cpu").detach(), alpha=0.5,
                            extent=(0, 224, 224, 0), interpolation='bilinear',
                            cmap='magma')
            ax[i][j].set_title("%s(%s)" % (label, prob.round(5)))
            k = k + 1
    fig.set_figwidth(16)
    fig.set_figheight(16)
    fig.tight_layout()

_show_cam_grid(0)
def _show_cam_grid(start_idx):
    """Plot a 5x5 grid of dataset images (from ``start_idx``) with CAM overlays.

    For each image: run it through the test-DataLoader pipeline, compute the
    CAM (head weights x conv feature maps), softmax the two logits, and draw
    the predicted class's activation map over the decoded image with the
    class name and probability as the title.
    """
    fig, ax = plt.subplots(5, 5)
    k = start_idx
    for i in range(5):
        for j in range(5):
            x, = first(dls.test_dl([PILImage.create(get_image_files(path)[k])]))
            # CAM: (2,512) head weights x (512,H,W) features -> (2,H,W) maps
            camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x).squeeze())
            a, b = net(x).tolist()[0]
            # Numerically stable 2-class softmax (subtract the max logit
            # before exponentiating to avoid overflow).
            m = max(a, b)
            ea, eb = np.exp(a - m), np.exp(b - m)
            normalprob, pneumoniaprob = ea / (ea + eb), eb / (ea + eb)
            # Single drawing path instead of two duplicated branches.
            if normalprob > pneumoniaprob:
                cls, label, prob = 0, "normal", normalprob
            else:
                cls, label, prob = 1, "pneumonia", pneumoniaprob
            dls.train.decode((x,))[0].squeeze().show(ax=ax[i][j])
            ax[i][j].imshow(camimg[cls].to("cpu").detach(), alpha=0.5,
                            extent=(0, 224, 224, 0), interpolation='bilinear',
                            cmap='magma')
            ax[i][j].set_title("%s(%s)" % (label, prob.round(5)))
            k = k + 1
    fig.set_figwidth(16)
    fig.set_figheight(16)
    fig.tight_layout()

_show_cam_grid(3000)

SAMPLE

get_image_files(path)[3021]
Path('/home/khy/chest_xray/chest_xray/train/PNEUMONIA/person12_bacteria_47.jpeg')
# Load one known-PNEUMONIA training image and push it through the test
# DataLoader pipeline so it receives the same 224x224 transforms as training.
img = PILImage.create(get_image_files(path)[3021])
img
x, = first(dls.test_dl([img]))  # image -> normalized batch tensor (1, 3, 224, 224)
x.shape
torch.Size([1, 3, 224, 224])

The stronger the model's evidence for its decision, the more the overlay shifts from blue $\to$ purple.

# One forward pass (the original ran the network twice to read two scalars),
# then a numerically stable 2-class softmax: subtracting the max logit
# prevents np.exp overflow for large logits while giving identical ratios.
a, b = net(x.to('cpu')).tolist()[0]
m = max(a, b)
np.exp(a - m)/(np.exp(a - m)+np.exp(b - m)), np.exp(b - m)/(np.exp(a - m)+np.exp(b - m))
(1.0052419841905753e-22, 1.0)
# CAM: multiply the (2, 512) head weights by the (512, H, W) conv feature
# maps to get one low-resolution activation map per class.
camimg = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x.to('cpu')).squeeze())
fig, (ax1,ax2) = plt.subplots(1,2) 
# Left panel: NORMAL-class activation map over the decoded image.
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[0].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
# Right panel: PNEUMONIA-class activation map.
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# Turn the PNEUMONIA CAM into soft masks in [0, 1): shift it to be
# non-negative, then A1 = exp(-0.02 * cam) downweights strongly-activated
# regions while A2 = 1 - A1 keeps them.
test=camimg[1]-torch.min(camimg[1])
A1=torch.exp(-0.02*test)
A2=1-A1
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A2.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax1.set_title("MODE1 WEIGHT")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A1.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='cool')
ax2.set_title("MODE1 RES WEIGHT")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# Upsample the low-resolution masks (presumably 7x7 for ResNet34 at 224 --
# confirm) to the 224x224 input resolution and apply them pixelwise:
# x1 = residual image with the CAM region suppressed (shifted non-negative),
# x12 = the CAM-highlighted part of the image.
X1=np.array(A1.to("cpu").detach(),dtype=np.float32)
Y1=torch.Tensor(cv2.resize(X1,(224,224),interpolation=cv2.INTER_LINEAR))
x1=x.squeeze().to('cpu')*Y1-torch.min(x.squeeze().to('cpu')*Y1)
X12=np.array(A2.to("cpu").detach(),dtype=np.float32)
Y12=torch.Tensor(cv2.resize(X12,(224,224),interpolation=cv2.INTER_LINEAR))
x12=x.squeeze().to('cpu')*Y12#-torch.min(x.squeeze().to('cpu')*Y12)
  • Splitting the image by the 1st CAM result gives the following.
# Show the original image, then the MODE1 split: the CAM-highlighted part
# (x12) and the residual (x1).  The scale factors (0.3 / 0.2) only adjust
# display brightness.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
#
fig, (ax1, ax2) = plt.subplots(1,2) 
(x12*0.3).squeeze().show(ax=ax1)  # MODE1: CAM-weighted part of the image
(x1*0.2).squeeze().show(ax=ax2)  # MODE1_res: residual after CAM suppression
ax1.set_title("MODE1")
ax2.set_title("MODE1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Re-batch the residual image and make sure both model halves are on the
# CPU before the second CAM pass.
x1=x1.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# Second-round CAM computed on the residual image x1.
camimg1 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x1).squeeze())
  • CAM

    • Overlay the CAM result on mode1_res
# Overlay the second-round CAM (camimg1) on the residual image, one panel
# per class.
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
(x1*0.2).squeeze().show(ax=ax1)
ax1.imshow(camimg1[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
(x1*0.2).squeeze().show(ax=ax2)
ax2.imshow(camimg1[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# Side-by-side comparison: first-round CAM on the original image vs.
# second-round CAM on the residual.
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(camimg[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("1ST CAM")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(camimg1[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("2ND CAM")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
x1.shape
torch.Size([1, 3, 224, 224])
# One forward pass (the original ran the network twice), then a numerically
# stable 2-class softmax: subtracting the max logit prevents np.exp overflow
# while leaving the probabilities unchanged.
a1, b1 = net(x1).tolist()[0]
m1 = max(a1, b1)
np.exp(a1 - m1)/(np.exp(a1 - m1)+np.exp(b1 - m1)), np.exp(b1 - m1)/(np.exp(a1 - m1)+np.exp(b1 - m1))
(1.7377578816154669e-12, 0.9999999999982623)
# MODE2 masks: same exponential-decay trick as MODE1, but applied to the
# second-round CAM and with a stronger decay rate (0.04 vs 0.02).
test1=camimg1[1]-torch.min(camimg1[1])
A3=torch.exp(-0.04*test1)  
A4=1-A3
fig, (ax1, ax2) = plt.subplots(1,2) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.imshow(A3.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='binary')
ax1.set_title("MODE2 WEIGHT")
#
dls.train.decode((x,))[0].squeeze().show(ax=ax2)
ax2.imshow(A4.data.to("cpu").detach(),alpha=0.5,extent=(0,224,224,0),interpolation='bilinear',cmap='binary')
ax2.set_title("MODE2 RES WEIGHT")
#
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# Upsample the MODE2 masks and compose them with the MODE1 masks:
# x3 = residual after suppressing both CAM rounds (Y1*Y3, shifted
# non-negative); x4 = the region highlighted by both rounds (Y12*Y4).
X3=np.array(A3.to("cpu").detach(),dtype=np.float32)
Y3=torch.Tensor(cv2.resize(X3,(224,224),interpolation=cv2.INTER_LINEAR))
x3=x.squeeze().to('cpu')*Y1*Y3-torch.min(x.squeeze().to('cpu')*Y1*Y3)
#x1=x.squeeze().to('cpu')*Y1-torch.min(x.squeeze().to('cpu')*Y1)
X4=np.array(A4.to("cpu").detach(),dtype=np.float32)
Y4=torch.Tensor(cv2.resize(X4,(224,224),interpolation=cv2.INTER_LINEAR))
x4=x.squeeze().to('cpu')*Y12*Y4
#x12=x.squeeze().to('cpu')*Y12
  • Splitting the image by the 2nd CAM result gives the following.
# Show the original image, the MODE1 split, and the MODE2 split side by
# side.  The scalar factors (0.5, 0.3, 3, ...) only rescale brightness for
# display.
fig, (ax1) = plt.subplots(1,1) 
dls.train.decode((x,))[0].squeeze().show(ax=ax1)
ax1.set_title("ORIGINAL")
fig.set_figwidth(4)            
fig.set_figheight(4)
fig.tight_layout()
#
fig, (ax1, ax2) = plt.subplots(1,2) 
(x12*0.5).squeeze().show(ax=ax1)  
(x1*0.3).squeeze().show(ax=ax2)  
ax1.set_title("MODE1")
ax2.set_title("MODE1 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
#
fig, (ax1, ax2) = plt.subplots(1,2) 
(x4*3).squeeze().show(ax=ax1)  
(x3*0.3).squeeze().show(ax=ax2)  
ax1.set_title("MODE2")
ax2.set_title("MODE2 RES")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
# Re-batch the doubly-masked residual and keep both model halves on the
# CPU before the third CAM pass.
x3=x3.reshape(1,3,224,224)
net1.to('cpu')
net2.to('cpu')
Sequential(
  (0): AdaptiveAvgPool2d(output_size=1)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=512, out_features=2, bias=False)
)
# Third-round CAM computed on the MODE2 residual x3.
camimg2 = torch.einsum('ij,jkl -> ikl', net2[2].weight, net1(x3).squeeze())
  • CAM
# Overlay the third-round CAM on the MODE2 residual, one panel per class.
fig, (ax1,ax2) = plt.subplots(1,2) 
# 
(x3*0.3).squeeze().show(ax=ax1)
ax1.imshow(camimg2[0].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax1.set_title("NORMAL PART")
#
(x3*0.3).squeeze().show(ax=ax2)
ax2.imshow(camimg2[1].to("cpu").detach(),alpha=0.5,extent=(0,223,223,0),interpolation='bilinear',cmap='cool')
ax2.set_title("DISEASE PART")
fig.set_figwidth(8)            
fig.set_figheight(8)
fig.tight_layout()
# One forward pass (the original ran the network twice), then a numerically
# stable 2-class softmax: subtracting the max logit prevents np.exp overflow
# while leaving the probabilities unchanged.
a2, b2 = net(x3).tolist()[0]
m2 = max(a2, b2)
np.exp(a2 - m2)/(np.exp(a2 - m2)+np.exp(b2 - m2)), np.exp(b2 - m2)/(np.exp(a2 - m2)+np.exp(b2 - m2))
(5.853030916370377e-16, 0.9999999999999994)